{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dspy\n",
    "from dsp.utils import deduplicate\n",
    "from dspy.datasets import HotPotQA\n",
    "from dspy.predict.retry import Retry\n",
    "from dspy.teleprompt import BootstrapFewShot, BootstrapFewShotWithRandomSearch\n",
    "from dspy.evaluate.evaluate import Evaluate\n",
    "\n",
    "from dspy.primitives.assertions import assert_transform_module, backtrack_handler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import openai\n",
    "# Read the OpenAI key from the environment; os.getenv returns None when the variable is unset.\n",
    "openai.api_key = os.getenv('OPENAI_API_KEY')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrieval model: a ColBERTv2 client pointed at the URL below.\n",
    "colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')\n",
    "dspy.settings.configure(rm=colbertv2_wiki17_abstracts)\n",
    "# Language model: gpt-3.5-turbo created with max_tokens=500.\n",
    "turbo = dspy.OpenAI(model='gpt-3.5-turbo', max_tokens=500)\n",
    "# Second configure call supplies lm, trace and temperature; rm was supplied by the call above.\n",
    "dspy.settings.configure(lm=turbo, trace=[], temperature=0.7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# HotPotQA with fixed seeds, requesting 300 train examples, 300 dev examples and no test examples.\n",
    "dataset = HotPotQA(train_seed=1, train_size=300, eval_seed=2023, dev_size=300, test_size=0)\n",
    "# with_inputs('question') is applied to every example of both splits.\n",
    "trainset = [x.with_inputs('question') for x in dataset.train]\n",
    "devset = [x.with_inputs('question') for x in dataset.dev]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Suggestion helper functions and Teleprompter metric\n",
    "\n",
    "def validate_query_distinction_local(previous_queries, query):\n",
    "    \"\"\"Return True when `query` is considered distinct from all `previous_queries`.\n",
    "\n",
    "    Args:\n",
    "        previous_queries: queries issued so far; an empty sequence or None\n",
    "            means there is nothing to compare against.\n",
    "        query: the newly generated query string.\n",
    "\n",
    "    Returns:\n",
    "        bool: False when dspy.evaluate.answer_exact_match_str (called with\n",
    "        frac=0.8) reports a match against the previous queries, True otherwise.\n",
    "    \"\"\"\n",
    "    # Truthiness check (rather than `== []`) also covers empty tuples and None.\n",
    "    if not previous_queries:\n",
    "        return True\n",
    "    return not dspy.evaluate.answer_exact_match_str(query, previous_queries, frac=0.8)\n",
    "\n",
    "\n",
    "def validate_context_and_answer_and_hops(example, pred, trace=None):\n",
    "    \"\"\"Teleprompter metric: accept a prediction only if it passes both the\n",
    "    answer_exact_match and the answer_passage_match checks from dspy.evaluate.\n",
    "\n",
    "    `trace` is accepted but not used.\n",
    "    \"\"\"\n",
    "    checks = (dspy.evaluate.answer_exact_match, dspy.evaluate.answer_passage_match)\n",
    "    # all() stops at the first failing check, so the passage check only runs\n",
    "    # once the answer check has succeeded.\n",
    "    return all(check(example, pred) for check in checks)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extrinsic metrics\n",
    "\n",
    "def gold_passages_retrieved(example, pred, trace=None):\n",
    "    \"\"\"Extrinsic metric: True when every gold title appears among the retrieved passages.\n",
    "\n",
    "    Each entry of `pred.context` is split on ' | ' and its first piece is used as the\n",
    "    title; both sides go through dspy.evaluate.normalize_text before the comparison.\n",
    "    `trace` is accepted but not used.\n",
    "    \"\"\"\n",
    "    normalize = dspy.evaluate.normalize_text\n",
    "    wanted = {normalize(title) for title in example['gold_titles']}\n",
    "    retrieved = {normalize(passage.split(' | ')[0]) for passage in pred.context}\n",
    "\n",
    "    return wanted <= retrieved"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# signatures of dspy modules\n",
    "class GenerateAnswer(dspy.Signature):\n",
    "    \"\"\"Answer questions with short factoid answers.\"\"\"\n",
    "\n",
    "    # Inputs: supporting context and the question.\n",
    "    # NOTE(review): the class docstring and field descriptions are presumably used by\n",
    "    # DSPy when building the prompt -- treat them as runtime text, not documentation.\n",
    "    context = dspy.InputField(desc=\"may contain relevant facts\")\n",
    "    question = dspy.InputField()\n",
    "    # Output: the answer text.\n",
    "    answer = dspy.OutputField(desc=\"often between 1 and 5 words\")\n",
    "\n",
    "\n",
    "class GenerateSearchQuery(dspy.Signature):\n",
    "    \"\"\"Write a simple search query that will help answer a complex question.\"\"\"\n",
    "\n",
    "    # Inputs: the context gathered so far and the original question.\n",
    "    # NOTE(review): the class docstring and field descriptions are presumably used by\n",
    "    # DSPy when building the prompt -- treat them as runtime text, not documentation.\n",
    "    context = dspy.InputField(desc=\"may contain relevant facts\")\n",
    "    question = dspy.InputField()\n",
    "    # Output: the search query.\n",
    "    query = dspy.OutputField()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def all_queries_distinct(prev_queries):\n",
    "    \"\"\"Return True when every query is distinct from all the queries listed before it.\n",
    "\n",
    "    Each entry is checked with validate_query_distinction_local against the slice\n",
    "    of `prev_queries` that precedes it; checking stops at the first failure.\n",
    "    \"\"\"\n",
    "    return all(\n",
    "        validate_query_distinction_local(prev_queries[:idx], candidate)\n",
    "        for idx, candidate in enumerate(prev_queries)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimplifiedBaleen(dspy.Module):\n",
    "    \"\"\"Multi-hop QA pipeline without assertions: one query generator per hop,\n",
    "    a retriever, and a final answer generator.\"\"\"\n",
    "\n",
    "    def __init__(self, passages_per_hop=2, max_hops=2):\n",
    "        \"\"\"Create `max_hops` query generators, a retriever returning\n",
    "        `passages_per_hop` passages, and the answer generator.\"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        # Attribute assignment order is intentionally left as it was.\n",
    "        self.generate_query = [dspy.ChainOfThought(GenerateSearchQuery) for _ in range(max_hops)]\n",
    "        self.retrieve = dspy.Retrieve(k=passages_per_hop)\n",
    "        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)\n",
    "        self.max_hops = max_hops\n",
    "\n",
    "        # Bookkeeping counter read by the evaluation helper; not used by the pipeline itself.\n",
    "        self.passed_suggestions = 0\n",
    "\n",
    "    def forward(self, question):\n",
    "        \"\"\"Run `max_hops` rounds of query generation and retrieval, then answer.\n",
    "\n",
    "        Returns a dspy.Prediction carrying the final `context` and `answer`.\n",
    "        \"\"\"\n",
    "        context = []\n",
    "        # The question seeds the history, so generated queries are also compared against it.\n",
    "        prev_queries = [question]\n",
    "\n",
    "        for hop_index in range(self.max_hops):\n",
    "            hop_prediction = self.generate_query[hop_index](context=context, question=question)\n",
    "            query = hop_prediction.query\n",
    "            prev_queries.append(query)\n",
    "            retrieved = self.retrieve(query)\n",
    "            context = deduplicate(context + retrieved.passages)\n",
    "\n",
    "        # Count forward passes in which all queries were distinct.\n",
    "        if all_queries_distinct(prev_queries):\n",
    "            self.passed_suggestions += 1\n",
    "\n",
    "        answer = self.generate_answer(context=context, question=question).answer\n",
    "        return dspy.Prediction(context=context, answer=answer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimplifiedBaleenAssertions(dspy.Module):\n",
    "    \"\"\"Multi-hop QA pipeline that runs two dspy.Suggest checks on every generated query:\n",
    "    a length check and a distinctness check against the earlier queries.\"\"\"\n",
    "\n",
    "    def __init__(self, passages_per_hop=2, max_hops=2):\n",
    "        \"\"\"Create `max_hops` query generators, a retriever returning\n",
    "        `passages_per_hop` passages, and the answer generator.\"\"\"\n",
    "        super().__init__()\n",
    "        self.generate_query = [dspy.ChainOfThought(GenerateSearchQuery) for _ in range(max_hops)]\n",
    "        self.retrieve = dspy.Retrieve(k=passages_per_hop)\n",
    "        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)\n",
    "        self.max_hops = max_hops\n",
    "\n",
    "        # for evaluating assertions only\n",
    "        self.passed_suggestions = 0\n",
    "\n",
    "    def forward(self, question):\n",
    "        \"\"\"Run `max_hops` rounds of query generation, suggestion checks and retrieval,\n",
    "        then answer. Returns a dspy.Prediction with the final `context` and `answer`.\"\"\"\n",
    "        context = []\n",
    "        # The question seeds the history, so the first query is also compared against it.\n",
    "        prev_queries = [question]\n",
    "\n",
    "        for hop in range(self.max_hops):\n",
    "            query = self.generate_query[hop](context=context, question=question).query\n",
    "\n",
    "            # NOTE(review): how a failed Suggest is handled (retry/backtracking) depends on\n",
    "            # the wrapper applied by the caller -- confirm against the DSPy assertions docs.\n",
    "            dspy.Suggest(\n",
    "                len(query) <= 100,\n",
    "                \"Query should be short and less than 100 characters\",\n",
    "            )\n",
    "\n",
    "            # The new query is compared before it is appended below, so it is never\n",
    "            # checked against itself.\n",
    "            dspy.Suggest(\n",
    "                validate_query_distinction_local(prev_queries, query),\n",
    "                \"Query should be distinct from: \"\n",
    "                + \"; \".join(f\"{i+1}) {q}\" for i, q in enumerate(prev_queries)),\n",
    "            )\n",
    "\n",
    "            prev_queries.append(query)\n",
    "            passages = self.retrieve(query).passages\n",
    "            context = deduplicate(context + passages)\n",
    "        \n",
    "        # Count forward passes in which all queries were distinct.\n",
    "        if all_queries_distinct(prev_queries):\n",
    "            self.passed_suggestions += 1\n",
    "\n",
    "        pred = self.generate_answer(context=context, question=question)\n",
    "        pred = dspy.Prediction(context=context, answer=pred.answer)\n",
    "        return pred\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared evaluator over devset: 10 threads, progress display on, table display off.\n",
    "evaluate_on_hotpotqa = Evaluate(devset=devset, num_threads=10, display_progress=True, display_table=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(module):\n",
    "    \"\"\"Evaluate `module` with evaluate_on_hotpotqa and print three scores.\n",
    "\n",
    "    The module's `passed_suggestions` counter is reset first. Statement order matters:\n",
    "    the suggestions score is read after the retrieval pass and before the accuracy\n",
    "    pass, which calls the evaluator on the module a second time.\n",
    "    \"\"\"\n",
    "    module.passed_suggestions = 0\n",
    "\n",
    "    retrieval_score = evaluate_on_hotpotqa(module, metric=gold_passages_retrieved)\n",
    "\n",
    "    # Counter value after the retrieval pass, as a percentage of the dev set size.\n",
    "    dev_count = len(devset)\n",
    "    suggestions_score = module.passed_suggestions / dev_count * 100\n",
    "\n",
    "    accuracy_score = evaluate_on_hotpotqa(module, metric=dspy.evaluate.answer_exact_match)\n",
    "\n",
    "    print(f\"## Suggestions Score: {suggestions_score}\")\n",
    "\n",
    "    print(f\"## Retrieval Score: {retrieval_score}\")\n",
    "    print(f\"## Accuracy Score: {accuracy_score}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# No Compilation + No Assertion: a plain SimplifiedBaleen is evaluated as-is.\n",
    "baleen = SimplifiedBaleen()\n",
    "evaluate(baleen)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# No Compilation + Yes Assertion: each named predictor is mapped through Retry and the\n",
    "# module is wrapped by assert_transform_module with backtrack_handler before evaluation.\n",
    "baleen_with_assertions = assert_transform_module(SimplifiedBaleenAssertions().map_named_predictors(Retry), backtrack_handler) \n",
    "evaluate(baleen_with_assertions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Passed as max_bootstrapped_demos to the teleprompters created in the compilation cells.\n",
    "max_bootstrapped_demos = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Yes Compilation + No Assertion\n",
    "baleen = SimplifiedBaleen()\n",
    "teleprompter = BootstrapFewShotWithRandomSearch(\n",
    "    metric=validate_context_and_answer_and_hops,\n",
    "    max_bootstrapped_demos=max_bootstrapped_demos,\n",
    "    num_candidate_programs=6,\n",
    ")\n",
    "\n",
    "# NOTE(review): valset=devset is also the set evaluate() scores on below, so the\n",
    "# reported numbers may be optimistic -- confirm this is intended.\n",
    "compiled_baleen = teleprompter.compile(\n",
    "    student=SimplifiedBaleen(),\n",
    "    teacher=baleen,  # use the instance created above rather than leaving it unused\n",
    "    trainset=trainset,\n",
    "    valset=devset,\n",
    ")\n",
    "evaluate(compiled_baleen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Yes Compilation + Yes Assertion\n",
    "# Teacher: a plain SimplifiedBaleen without assertions.\n",
    "baleen = SimplifiedBaleen()\n",
    "teleprompter = BootstrapFewShotWithRandomSearch(\n",
    "    metric=validate_context_and_answer_and_hops,\n",
    "    max_bootstrapped_demos=max_bootstrapped_demos,\n",
    "    num_candidate_programs=6,\n",
    ")\n",
    "# Student: the assertion-enabled module, with each named predictor mapped through Retry\n",
    "# and the whole module wrapped by assert_transform_module using backtrack_handler.\n",
    "# NOTE(review): valset=devset is also the set evaluate() scores on below, so the\n",
    "# reported numbers may be optimistic -- confirm this is intended.\n",
    "compiled_baleen = teleprompter.compile(\n",
    "    student=assert_transform_module(\n",
    "        SimplifiedBaleenAssertions().map_named_predictors(Retry),\n",
    "        backtrack_handler,\n",
    "    ),\n",
    "    teacher=baleen,\n",
    "    trainset=trainset,\n",
    "    valset=devset\n",
    ")\n",
    "evaluate(compiled_baleen)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
